package com.github.hgkmail.hello.leetcode101.pointer.tree.bst;

import com.github.hgkmail.hello.leetcode101.base.CommonUtil;
import com.github.hgkmail.hello.leetcode101.base.TreeNode;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

//难点：中序遍历prev指针的设置，prev是[全局变量]，不是局部变量！
public class LC99RecoverBinarySearchTree {
    //全局变量放这里，不要再弄单元素的list了。。。
    TreeNode prev;

    //注意是中序遍历
    public void inorder(TreeNode root, TreeNode[] mistakes) {
        if (root==null) {
            return;
        }
        if (root.left!=null) {
            inorder(root.left, mistakes);
        }
        //这个画图可以理解
        if (prev!=null && root.val < prev.val) {
            if (mistakes[0]==null) {
                mistakes[0]=prev;
                mistakes[1]=root;
            } else {
                mistakes[1]=root;
            }
        }
        //prev就是右子树的根节点。。。
        prev=root;
        if (root.right!=null) {
            inorder(root.right, mistakes);
        }
    }

    public void recoverTree(TreeNode root) {
        //BST中序遍历是递增数列，交换2个节点，会产生2处变小的地方
        //需要设置一个[全局的]prev指针，来比较大小，判断是变大还是变小
        TreeNode[] mistakes=new TreeNode[2];

        inorder(root, mistakes);

        if (mistakes[0]!=null && mistakes[1]!=null) {
            int temp=mistakes[0].val;
            mistakes[0].val=mistakes[1].val;
            mistakes[1].val=temp;
        }
    }

    public void inorderDfs(TreeNode root, List<TreeNode> inorderRes) {
        if (root==null) {
            return;
        }
        inorderDfs(root.left, inorderRes);
        inorderRes.add(root);
        inorderDfs(root.right, inorderRes);
    }

    //通过中序遍历把BST转成列表，找出2处变小的地方
    public void recoverTree2(TreeNode root) {
        List<TreeNode> inorderRes=new ArrayList<>();
        inorderDfs(root, inorderRes);

        TreeNode mistake1=null, mistake2=null;
        int sz=inorderRes.size();
        for (int i = 1; i < sz; i++) {
            if (inorderRes.get(i).val<inorderRes.get(i-1).val) {
                if (mistake1==null) {
                    mistake1=inorderRes.get(i-1);
                }
                mistake2=inorderRes.get(i);
            }
        }
        if (mistake1!=null && mistake2!=null) {
            int temp=mistake1.val;
            mistake1.val=mistake2.val;
            mistake2.val=temp;
        }

    }

    public void inorderVisit(TreeNode root) {
        if (root==null) {
            return;
        }
        inorderVisit(root.left);
        System.out.print(root.val+" ");
        inorderVisit(root.right);
    }

    public static void main(String[] args) {
        TreeNode root= CommonUtil.deserializeBinaryTree("1,#,2,3,#,#,#");
//        new LC99RecoverBinarySearchTree().inorderVisit(root);
        new LC99RecoverBinarySearchTree().recoverTree2(root);
        System.out.println(CommonUtil.serializeBinaryTree(root));
    }
}
